import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
from utils.colors import colors_dictionary

import numpy as np 
# Output directory for the PDF figures written by the plotting functions in
# this module. It is a relative path, so it resolves against the current
# working directory at call time.
results_dir: str = "results/"

def CurvatureGraph(Curvature1,Curvature2,Curvature3,Curvature4,Curvature5,Curvature6,edge_index,colorlabels,labels ,title:str):
    """Plot a six-series comparison of edge-curvature distributions and save it as a PDF.

    Layout (built with ``plt.subplot_mosaic``):
      * top row ``h1``-``h4``: histograms of ``Curvature1``-``Curvature4`` on 41
        bin edges spanning [-2, 5]. Each panel has an inset that re-histograms
        only the values below -0.5, zoomed to x in [-2, -0.5];
      * middle ``D``: seaborn KDEs of ``Curvature1``-``Curvature4`` on one axis;
      * bottom ``h5``/``h6``: histograms of ``Curvature5``/``Curvature6`` on a
        wider range, plus one legend for all six series below ``h5``.

    The figure is written to ``<results_dir>/CurvaturesComparison_<title>.pdf``.
    The directory is created if it is missing.

    Args:
        Curvature1 ... Curvature6: curvature values, each indexed as
            ``CurvatureK[edge_index]``. The inset filter uses ``torch.where``,
            so these are presumably torch tensors -- TODO confirm with callers.
        edge_index: index applied to every curvature to select per-edge values.
            Its exact form (tuple of index tensors, mask, ...) is not visible
            here; confirm against callers.
        colorlabels: six keys into ``colors_dictionary``, one per curvature,
            in order.
        labels: six legend labels, one per curvature, in order.
        title: dataset name. It is drawn in bold above ``h1`` and used in the
            output file name. ``"Chameleon"`` also selects dataset-specific
            bottom-row bins and top-row x-limits. ``"Cora"``/``"Citeseer"``
            select a different top-row x-limit.
    """
    FigCurvature, axCurvature = plt.subplot_mosaic(
        [['h1', 'h2', 'h3', 'h4'],
         ['.', '.', '.', '.'],
         ['D', 'D', 'D', 'D'],
         ['.', '.', '.', '.'],
         ['h5', 'h5', 'h6', 'h6']],
        figsize=(5.5, 2.85),
        gridspec_kw={'width_ratios': [1, 1, 1, 1],
                     'height_ratios': [0.7, 0.23, 0.5, 0.23, 0.65],
                     'hspace': 0},
        dpi=450)

    top_keys = ['h1', 'h2', 'h3', 'h4']
    top_curvatures = [Curvature1, Curvature2, Curvature3, Curvature4]
    top_counts = []      # main-histogram bin counts of h1..h4 (for the shared y-limit)
    inset_counts = []    # inset bin counts of h1..h4 (for the shared inset y-limit)
    insets = []          # inset axes, in panel order
    legend_patches = []  # the first bar of each histogram carries its legend label

    for i, (key, curvature) in enumerate(zip(top_keys, top_curvatures)):
        values = curvature[edge_index]
        color = colors_dictionary[colorlabels[i]]

        counts, _, patches = axCurvature[key].hist(
            values, bins=np.linspace(-2, 5, 41), linewidth=0.3, edgecolor='black',
            color=color, label=labels[i], rwidth=0.7)
        top_counts.append(counts)
        legend_patches.append(patches[0])

        # Inset zooming in on the negatively curved tail (values < -0.5).
        inset = axCurvature[key].inset_axes([0.55, 0.55, 0.4, 0.4])
        tail_counts, _, _ = inset.hist(
            values[torch.where(values < -0.5)], bins=np.linspace(-2, 5, 41),
            color=color, linewidth=0.1, edgecolor='black')
        inset.set_xlim(-2, -0.5)
        inset.tick_params(labelsize=6)
        inset_counts.append(tail_counts)
        insets.append(inset)

    # Bottom row. The bottom counts are deliberately discarded: only the top row
    # feeds the shared y-limit below (they previously overwrote Curvature4's
    # counts and leaked into that limit).
    if title == "Chameleon":
        # Chameleon uses its own, much wider bin ranges and draws no bar edges.
        _, _, patches6 = axCurvature['h6'].hist(
            Curvature6[edge_index], bins=np.linspace(-110, 9000, 800),
            color=colors_dictionary[colorlabels[5]], rwidth=0.7, label=labels[5])
        _, _, patches5 = axCurvature['h5'].hist(
            Curvature5[edge_index], bins=np.linspace(-110, 80, 180),
            color=colors_dictionary[colorlabels[4]], rwidth=0.7, label=labels[4])
    else:
        _, _, patches6 = axCurvature['h6'].hist(
            Curvature6[edge_index], bins=np.linspace(-110, 10, 100),
            linewidth=0.5, edgecolor='black', color=colors_dictionary[colorlabels[5]],
            rwidth=0.7, label=labels[5])
        _, _, patches5 = axCurvature['h5'].hist(
            Curvature5[edge_index], bins=np.linspace(-110, 10, 100),
            linewidth=0.5, edgecolor='black', color=colors_dictionary[colorlabels[4]],
            rwidth=0.7, label=labels[4])
    legend_patches.extend([patches5[0], patches6[0]])

    # One legend for all six series, placed below the bottom row.
    legend_labels = [patch.get_label() for patch in legend_patches]
    axCurvature['h5'].legend(legend_patches, legend_labels, loc="lower left", fontsize=7,
                             ncol=6, bbox_to_anchor=(-0.025, -0.8))

    for key in top_keys:
        axCurvature[key].set_xticks([-2, 0, 5])
        if title == "Chameleon":
            axCurvature[key].set_xlim(-2.2, 9)
        if title in ("Cora", "Citeseer"):
            axCurvature[key].set_xlim(-2.2, 6.5)

    for key in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
        axCurvature[key].tick_params(labelsize=6)

    axCurvature['D'].tick_params(labelsize=6)
    axCurvature['D'].set_ylabel("Density", fontsize=7)

    for i, curvature in enumerate(top_curvatures):
        sns.kdeplot(data=curvature[edge_index], ax=axCurvature['D'], label=labels[i],
                    color=colors_dictionary[colorlabels[i]])

    # Chain-share the y-axis across the top row; only h1 keeps visible y tick labels.
    axCurvature['h1'].sharey(axCurvature['h2'])
    axCurvature['h2'].sharey(axCurvature['h3'])
    axCurvature['h3'].sharey(axCurvature['h4'])

    axCurvature['h2'].set_yticklabels([])
    axCurvature['h3'].set_yticklabels([])
    axCurvature['h4'].set_yticklabels([])

    # The shared y-limit covers the tallest bar of the four top-row histograms.
    axCurvature['h1'].set_ylim(0, max(counts.max() for counts in top_counts) + 50)

    # Same y-sharing scheme for the insets.
    ax1, ax2, ax3, ax4 = insets
    ax1.sharey(ax2)
    ax2.sharey(ax3)
    ax3.sharey(ax4)

    ax2.set_yticklabels([])
    ax3.set_yticklabels([])

    ax1.set_ylim(0, max(counts.max() for counts in inset_counts) + 20)

    axCurvature['h5'].sharey(axCurvature['h6'])
    # Hide h6's y tick labels by hiding the Text objects rather than calling
    # set_yticklabels([]). sharey() also shares the tick formatter, so replacing
    # it would presumably blank h5's labels too.
    for tick_label in axCurvature['h6'].get_yticklabels():
        tick_label.set_visible(False)

    # Bold dataset name above the top-left panel (axes-fraction coordinates).
    axCurvature['h1'].text(0, 1.3,
                           horizontalalignment='left',
                           verticalalignment='top',
                           transform=axCurvature['h1'].transAxes,
                           s=r"$\mathbf{" + title + "}$", fontsize=9)

    os.makedirs(results_dir, exist_ok=True)
    FigCurvature.savefig(os.path.join(results_dir, "CurvaturesComparison_" + title + ".pdf"),
                         format='pdf', bbox_inches='tight')

def CurvatureGraph_ownedge(Curvature1,edge_index1,Curvature2,edge_index2,Curvature3,edge_index3, title:str):
    """Compare three curvature distributions, each sampled on its own edge set.

    Layout:
      * top row (``A``, ``B``, ``C``): histograms on 41 bin edges spanning
        [-2, 5], in blue, orange and green;
      * bottom (``D``): seaborn KDEs of the same three series, with a legend.

    The figure is saved to ``<results_dir>/CurvaturesComparison_<title>.pdf``
    (creating the directory if needed) and then shown with ``plt.show()``.

    Args:
        Curvature1, Curvature2, Curvature3: curvature values, indexed as
            ``CurvatureK[edge_indexK[0], edge_indexK[1]]``. Presumably a square
            node-by-node array/tensor -- confirm with callers.
        edge_index1, edge_index2, edge_index3: per-series edge indices. Row 0
            selects the first index and row 1 the second (presumably a (2, E)
            edge list -- confirm).
        title: drawn above panel ``B`` and used in the output file name.
    """
    FigCurvature, axCurvature = plt.subplot_mosaic(
        [['A', 'B', 'C'], ['D', 'D', 'D']],
        figsize=(7.08, 3.00),
        gridspec_kw={'width_ratios': [1, 1, 1], 'height_ratios': [0.8, 1]},
        dpi=300)

    # (panel key, curvature, edge index, color, KDE legend label)
    series = [
        ('A', Curvature1, edge_index1, 'blue', 'Curvature 1'),
        ('B', Curvature2, edge_index2, 'orange', 'Curvature 2'),
        ('C', Curvature3, edge_index3, 'green', 'Curvature 3'),
    ]
    edge_values = [curvature[index[0], index[1]] for _, curvature, index, _, _ in series]

    for (key, _, _, color, _), values in zip(series, edge_values):
        axCurvature[key].hist(values, bins=np.linspace(-2, 5, 41), color=color, rwidth=0.7)

    for (_, _, _, color, label), values in zip(series, edge_values):
        sns.kdeplot(data=values, ax=axCurvature['D'], label=label, color=color)

    # Chain-share the y-axis across the top row; only A keeps visible y tick labels.
    axCurvature['A'].sharey(axCurvature['B'])
    axCurvature['B'].sharey(axCurvature['C'])
    axCurvature['B'].set_yticklabels([])
    axCurvature['C'].set_yticklabels([])

    axCurvature['B'].set_title(title)
    axCurvature['D'].legend()
    axCurvature['D'].set_xlabel('Curvature')

    # Save before showing: with interactive/inline backends, show() may block or
    # close the figure before it has been written to disk.
    os.makedirs(results_dir, exist_ok=True)
    FigCurvature.savefig(os.path.join(results_dir, "CurvaturesComparison_" + title + ".pdf"),
                         format='pdf', bbox_inches='tight')
    plt.show()